package com.nstskj.study.netty.tcp.netty.protocol.message.in.factory.impl;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.ReflectUtil;
import com.nstskj.study.netty.tcp.netty.protocol.constants.ProtocolCmdConstants;
import com.nstskj.study.netty.tcp.netty.protocol.definition.ProtocolMsgBeanDefinition;
import com.nstskj.study.netty.tcp.netty.protocol.definition.ProtocolMsgFieldDefinition;
import com.nstskj.study.netty.tcp.netty.protocol.definition.strategy.ProtocolMsgBeanDefinitionStrategy;
import com.nstskj.study.netty.tcp.netty.protocol.message.in.msg.base.AbstractInputNetTcpMessage;
import com.nstskj.study.netty.tcp.netty.protocol.message.in.factory.NetMessageDecoderFactory;
import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.CodecException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

/**
 * Spring-managed {@link NetMessageDecoderFactory} that decodes one protocol frame from a
 * {@link ByteBuf} into the message bean registered for the frame's command code.
 *
 * <p>Frame layout as read by {@link #parse(ByteBuf)}, all multi-byte integers big-endian:
 * <pre>
 *   head (2) | length (4) | version (2, UTF-8) | cmd (2) | serial (4) | command data (length - 8)
 * </pre>
 *
 * @author ZhouChuGang
 * @version 1.0
 * @project nstskj-study-netty-tcp-spring
 * @date 2021/4/21 13:27
 */
@Slf4j
@Component
public class SpringNetMessageDecoderFactory implements NetMessageDecoderFactory {

    /**
     * Bytes covered by the frame's length field that are not command data. Matches the
     * version (2), command code (2) and serial (4) fields that are read after the length field.
     */
    private static final int LENGTH_FIELD_OVERHEAD = 8;

    @Autowired
    private ProtocolMsgBeanDefinitionStrategy protocolMsgBeanDefinitionStrategy;

    /**
     * Parses one frame and returns the populated message bean for its command code.
     *
     * <p>The fixed header is read first, the bean definition for the command code is looked up,
     * the command data is copied out of the buffer, declared fields are filled in by reflection,
     * and finally the bean's own {@code decoderCmdData()} hook is invoked.
     *
     * @param byteBuf buffer positioned at the start of a frame; its reader index is advanced
     *                by whatever has been consumed when this method returns or throws
     * @return the decoded message bean
     * @throws CodecException if the buffer is shorter than the protocol minimum, the head marker
     *                        is wrong, the command code has no bean definition, the declared
     *                        length is inconsistent with the buffer or the bean definition, a
     *                        field does not fit in the command data, or the bean's custom
     *                        decoding fails
     */
    @Override
    public AbstractInputNetTcpMessage parse(ByteBuf byteBuf) {
        // Reject frames that cannot even hold the fixed header.
        if (byteBuf.readableBytes() < ProtocolCmdConstants.PROTOCOL_MIN_LEN) {
            throw new CodecException("命令长度错误");
        }
        // Head marker. Note the two bytes are already consumed if the check below fails.
        short head = byteBuf.readShort();
        if (ProtocolCmdConstants.PROTOCOL_HEAD != head) {
            throw new CodecException("命令头错误");
        }
        // Declared length of everything that follows the length field.
        int length = byteBuf.readInt();
        // Version: two bytes interpreted as UTF-8 text.
        byte[] versionBytes = new byte[2];
        byteBuf.readBytes(versionBytes);
        String version = new String(versionBytes, StandardCharsets.UTF_8);
        // Command code.
        short cmd = byteBuf.readShort();
        // Serial number.
        int serial = byteBuf.readInt();

        // Look up the bean definition registered for this command code.
        ProtocolMsgBeanDefinition protocolMsgBeanDefinition = protocolMsgBeanDefinitionStrategy.getProtocolMsgBeanDefinition(cmd);
        if (Objects.isNull(protocolMsgBeanDefinition)) {
            throw new CodecException("命令不支持 cmd : " + cmd);
        }

        // The length field comes from the wire, so validate it before allocating: it must not
        // yield a negative payload size nor claim more bytes than the buffer actually holds.
        int cmdDataLen = length - LENGTH_FIELD_OVERHEAD;
        if (cmdDataLen < 0 || cmdDataLen > byteBuf.readableBytes()) {
            throw new CodecException("命令长度错误 cmd : " + cmd);
        }
        // A minimum of -1 in the definition disables the per-command minimum-length check.
        if ((protocolMsgBeanDefinition.getCmdDataMinLen() != -1) && (cmdDataLen < protocolMsgBeanDefinition.getCmdDataMinLen())) {
            throw new CodecException("命令长度错误 cmd : " + cmd);
        }

        // Instantiate the bean and copy the header values onto it.
        AbstractInputNetTcpMessage inputNetTcpMessage = ReflectUtil.newInstance(protocolMsgBeanDefinition.getAbstractInputNetTcpMessageClass());
        inputNetTcpMessage.setCmd(cmd);
        inputNetTcpMessage.setHead(head);
        inputNetTcpMessage.setSerial(serial);
        inputNetTcpMessage.setVersion(version);
        inputNetTcpMessage.setDataLength(length);

        // Copy the command data out of the buffer (an empty array when there is none).
        byte[] cmdData = new byte[cmdDataLen];
        if (cmdDataLen > 0) {
            byteBuf.readBytes(cmdData);
        }
        inputNetTcpMessage.setCmdData(cmdData);

        // Fill declared fields from the command data; skipped entirely for an empty payload.
        List<ProtocolMsgFieldDefinition> fieldDefinitions = protocolMsgBeanDefinition.getFieldDefinitions();
        if (CollUtil.isNotEmpty(fieldDefinitions) && cmdDataLen > 0) {
            for (ProtocolMsgFieldDefinition fieldDefinition : fieldDefinitions) {
                setProtocolMsgFieldValue(inputNetTcpMessage, fieldDefinition);
            }
        }

        try {
            // Give the bean a chance to run its own decoding on top of the declared fields.
            inputNetTcpMessage.decoderCmdData();
        } catch (Exception e) {
            log.error("执行自定义解析错误", e);
            // Keep the original failure as the cause so callers can diagnose it.
            throw new CodecException("协议解析错误", e);
        }
        return inputNetTcpMessage;
    }

    /**
     * Extracts one declared field from the message's command data and stores it on the bean
     * via reflection.
     *
     * <p>Numeric types are decoded big-endian from the first 1/2/4/8 bytes of the field slice.
     *
     * @param inputNetTcpMessage message bean whose command data has already been set
     * @param fieldDefinition    description of the field: name, start index, length and type
     * @throws CodecException if the field does not lie fully inside the command data, is too
     *                        short for its declared type, or has an unsupported type
     */
    private void setProtocolMsgFieldValue(AbstractInputNetTcpMessage inputNetTcpMessage, ProtocolMsgFieldDefinition fieldDefinition) {
        byte[] cmdData = inputNetTcpMessage.getCmdData();
        int start = fieldDefinition.getFieldStartIndex();
        int len = fieldDefinition.getFieldLen();
        // The whole [start, start + len) range must be inside cmdData. Written as a subtraction
        // so that start + len cannot overflow.
        if (cmdData == null || start < 0 || len < 0 || len > cmdData.length - start) {
            throw new CodecException("命令长度错误 cmd : " + inputNetTcpMessage.getCmd());
        }
        byte[] bytes = new byte[len];
        ArrayUtil.copy(cmdData, start, bytes, 0, bytes.length);

        String fieldName = fieldDefinition.getFieldName();
        // Dispatch on the declared field type.
        switch (fieldDefinition.getFieldTypeEnums()) {
            case STRING_TYPE:
                String valueStr = new String(bytes, StandardCharsets.UTF_8);
                ReflectUtil.setFieldValue(inputNetTcpMessage, fieldName, valueStr);
                break;
            case BYTE_TYPE:
                requireWidth(inputNetTcpMessage, fieldName, bytes, 1);
                ReflectUtil.setFieldValue(inputNetTcpMessage, fieldName, bytes[0]);
                break;
            case SHORT_TYPE:
                requireWidth(inputNetTcpMessage, fieldName, bytes, 2);
                short valueShort = (short) readBigEndian(bytes, 2);
                ReflectUtil.setFieldValue(inputNetTcpMessage, fieldName, valueShort);
                break;
            case INTEGER_TYPE:
                requireWidth(inputNetTcpMessage, fieldName, bytes, 4);
                int valueInt = (int) readBigEndian(bytes, 4);
                ReflectUtil.setFieldValue(inputNetTcpMessage, fieldName, valueInt);
                break;
            case LONG_TYPE:
                requireWidth(inputNetTcpMessage, fieldName, bytes, 8);
                long valueLong = readBigEndian(bytes, 8);
                ReflectUtil.setFieldValue(inputNetTcpMessage, fieldName, valueLong);
                break;
            case BYTE_ARRAY_TYPE:
                ReflectUtil.setFieldValue(inputNetTcpMessage, fieldName, bytes);
                break;
            default:
                throw new CodecException("不支持的字段类型 " + fieldDefinition.getFieldTypeEnums().toString());
        }
    }

    /**
     * Ensures the field slice holds at least {@code width} bytes, so that a mis-declared field
     * length surfaces as a {@link CodecException} rather than an array index failure.
     */
    private static void requireWidth(AbstractInputNetTcpMessage inputNetTcpMessage, String fieldName, byte[] bytes, int width) {
        if (bytes.length < width) {
            throw new CodecException("字段长度错误 cmd : " + inputNetTcpMessage.getCmd() + ", field : " + fieldName);
        }
    }

    /**
     * Combines the first {@code width} bytes big-endian into a long. Each byte is masked to its
     * unsigned value first; without the mask a byte >= 0x80 is sign-extended and corrupts the
     * higher-order bits of the result.
     */
    private static long readBigEndian(byte[] bytes, int width) {
        long value = 0L;
        for (int i = 0; i < width; i++) {
            value = (value << 8) | (bytes[i] & 0xFFL);
        }
        return value;
    }

}
